/workspace/Dropbox/projects/BlainMooersLab/BayesTheseus-PP/clojure_and_pymc5/notebooks/proteins.clj
(ns proteins
  (:require [tablecloth.api :as tc]
            [fastmath.core :as math]
            [fastmath.random :as random]
            [tech.v3.datatype :as dtype]
            [tech.v3.dataset :as dataset]
            [tech.v3.tensor :as tensor]
            [tech.v3.datatype.functional :as fun]
            [aerial.hanami.common :as hc]
            [aerial.hanami.templates :as ht]
            [scicloj.kindly.v3.kind :as kind]
            [scicloj.kindly.v3.api :as kindly]
            [scicloj.clay.v2.api :as clay]
            [libpython-clj2.python :refer [py. py.. py.-] :as py]
            [scicloj.noj.v1.vis :as vis]
            [scicloj.noj.v1.vis.python :as vis.python]
            [libpython-clj2.require :refer [require-python]]
            [util])
  (:import java.lang.Math))
...
(require-python '[builtins :as python]
                'operator
                '[arviz :as az]
                '[arviz.style :as az.style]
                '[pandas :as pd]
                '[matplotlib.pyplot :as plt]
                '[numpy :as np]
                '[numpy.random :as np.random]
                '[pymc :as pm]
                '[Bio.PDB.PDBParser]
                '[Bio.PDB]
                '[Bio.PDB.Polypeptide]
                '[pytensor]
                '[pytensor.tensor :as pt]
                '[math])
:ok
;; Base name of the first structure file; `extract-coordinates-from-pdb`
;; turns it into the path data/<name>.pdb.
(def protein-name1 "7ju5clean")
...
;; Base name of the second structure file; `extract-coordinates-from-pdb`
;; turns it into the path data/<name>.pdb.
;; NOTE(review): the "AF-...-model_v4" naming looks like an AlphaFold DB
;; model id -- confirm the provenance of this file.
(def protein-name2 "AF-A0A024R7T2-F1-model_v4-clean")
...
(defn extract-coordinates-from-pdb
  "Reads data/<protein-name>.pdb with Biopython's PDBParser and returns a
  tablecloth dataset with one row per standard amino-acid residue found in
  the chains of the structure's first model.

  Columns:
    :id              second element of the residue's `get_id` tuple
    :name            the residue name (`get_resname`)
    :ca-coordinates  float32 array with the coordinates of the residue's
                     CA atom

  Residues whose CA coordinates cannot be read (any Exception while
  looking up the atom or converting its coordinates) are dropped."
  ([protein-name]
   (let [pdb-path (str "data/" protein-name ".pdb")
         pdb-parser (Bio.PDB/PDBParser)
         structure (py. pdb-parser get_structure protein-name pdb-path)
         first-model (first structure)
         standard-aa? (fn [residue]
                        (Bio.PDB.Polypeptide/is_aa (py. residue get_resname)
                                                   :standard true))
         ;; Best effort: a residue without a usable CA atom yields nil here
         ;; and is filtered out below.
         ca-coordinates (fn [residue]
                          (try
                            (dtype/->array :float32
                                           (py. (util/brackets residue "CA")
                                                get_coord))
                            (catch Exception _ nil)))
         residue->row (fn [residue]
                        {:id (second (py. residue get_id))
                         :name (py. residue get_resname)
                         :ca-coordinates (ca-coordinates residue)})
         chain->rows (fn [chain]
                       (->> chain
                            (filter standard-aa?)
                            (map residue->row)
                            (filter :ca-coordinates)))]
     (tc/dataset (mapcat chain->rows first-model)))))
...
;; Preview the extracted dataset for the first protein.
(tc/update-columns (extract-coordinates-from-pdb protein-name1)
                   [:ca-coordinates]
                   ;; arrays -> vectors, purely so the output is readable
                   #(map vec %))
...
(defn center-1d
  "Returns `xs` with its mean subtracted from every element."
  [xs]
  (let [mean-of-xs (fun/mean xs)]
    (fun/- xs mean-of-xs)))
...
(defn center-columns
  "Applies `center-1d` along axis 0 of the tensor `xyzs`
  (via `tensor/map-axis`)."
  [xyzs]
  (tensor/map-axis xyzs center-1d 0))
...
(defn read-data
  "Loads two protein structures and prepares their paired CA coordinates.

  Arguments:
    prots  sequence of protein base names (as accepted by
           `extract-coordinates-from-pdb`); the first two are used.
    opts   optional map; `:limit`, when truthy, keeps only the first
           `limit` rows of the joined dataset.

  The two per-residue datasets are inner-joined on :id, so only residues
  whose id appears in both structures are kept.

  Returns a map with:
    :coords        seq of two tensors built from the joined
                   :ca-coordinates / :right.ca-coordinates columns
    :obs           vector of the same two tensors, centered along axis 0
    :obs-datasets  vector of datasets made from :obs with
                   `util/xyz-tensor->dataset`"
  ([prots]
   (read-data prots nil))
  ([prots {:keys [limit]}]
   ;; Use the caller's `prots`. Previously a local binding shadowed the
   ;; argument with the two hard-coded protein names, so it was ignored.
   (let [[dataset1 dataset2] (map extract-coordinates-from-pdb prots)
         joined-dataset (cond-> (tc/inner-join dataset1 dataset2 :id)
                          limit (tc/head limit))
         ;; After the join, the second dataset's clashing column carries
         ;; the :right. prefix.
         coords (->> [:ca-coordinates :right.ca-coordinates]
                     (map (fn [colname]
                            (tensor/->tensor (joined-dataset colname)))))
         obs (mapv center-columns coords)
         obs-datasets (mapv util/xyz-tensor->dataset obs)]
     {:coords coords
      :obs obs
      :obs-datasets obs-datasets})))
...
;; Preview the centered coordinates of the first four joined residues.
(:obs-datasets
 (read-data [protein-name1 protein-name2]
            {:limit 4}))
...

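Compare the two datasets visually: the centered CA coordinates of both proteins are drawn as 3D line-and-marker traces (protein 1 in purple, protein 2 in orange).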

;; Draw both centered structures in one 3D Plotly figure.
(let [view-limit 50
      {:keys [obs]} (read-data [protein-name1 protein-name2])
      ;; Each observation is transposed here and transposed back inside
      ;; `->plot-data`, mirroring the layout used by the model code.
      transposed (mapv #(tensor/transpose % [1 0]) obs)
      ->plot-data (fn [xyz]
                    (-> xyz
                        (tensor/transpose [1 0])
                        util/xyz-tensor->dataset
                        (tc/head view-limit)
                        util/prep-dataset-for-cljs))
      plot-input {:prot1-dataset (->plot-data (first transposed))
                  :prot2-dataset (->plot-data (second transposed))}]
  (kind/hiccup
   ;; The quoted fn is evaluated on the client side and receives
   ;; `plot-input` as its argument.
   ['(fn [{:keys [prot1-dataset
                  prot2-dataset]}]
       [plotly
        {:data [(-> prot1-dataset
                    (merge {:type :scatter3d
                            :mode :lines+markers
                            :opacity 1
                            :marker {:size 3
                                     :color "purple"}}))
                (-> prot2-dataset
                    (merge {:type :scatter3d
                            :mode :lines+markers
                            :opacity 1
                            :marker {:size 3
                                     :color "orange"}}))]}])
    plot-input]))
...
(defn rotate-q
  "Builds a symbolic 3x3 rotation matrix from the three components of `u`.

  u[0] splits the unit norm between two radii (sqrt(1 - u[0]) and
  sqrt(u[0])); u[1] and u[2] are scaled by 2*pi into angles. Together they
  give a quaternion (w, x, y, z) with w^2 + x^2 + y^2 + z^2 = 1, which is
  expanded into the corresponding rotation matrix and returned as a stack
  of three stacked rows.

  `u` is indexed with `util/brackets` and combined through Python's
  `operator` and `pytensor.tensor` functions.
  NOTE(review): presumably each component of `u` lies in [0, 1] -- the
  square roots require it for u[0]; confirm against the caller."
  [u]
  (let [two-pi (* 2 Math/PI)
        component (fn [i] (util/brackets u i))
        angle-1 (operator/mul (component 1) two-pi)
        angle-2 (operator/mul (component 2) two-pi)
        radius-1 (pt/sqrt (operator/sub 1 (component 0)))
        radius-2 (pt/sqrt (component 0))
        ;; quaternion components
        w (operator/mul (pt/cos angle-2) radius-2)
        x (operator/mul (pt/sin angle-1) radius-1)
        y (operator/mul (pt/cos angle-1) radius-1)
        z (operator/mul (pt/sin angle-2) radius-2)
        ;; (p^2 + q^2) - (r^2 + s^2): shape of every diagonal entry
        diagonal (fn [p q r s]
                   (operator/sub (operator/add (pt/sqr p) (pt/sqr q))
                                 (operator/add (pt/sqr r) (pt/sqr s))))
        ;; 2 * (a*b + c*d) and 2 * (a*b - c*d): the off-diagonal entries
        twice-sum (fn [a b c d]
                    (operator/mul 2
                                  (operator/add (operator/mul a b)
                                                (operator/mul c d))))
        twice-diff (fn [a b c d]
                     (operator/mul 2
                                   (operator/sub (operator/mul a b)
                                                 (operator/mul c d))))
        row-0 (pt/stack [(diagonal w x y z)
                         (twice-diff x y w z)
                         (twice-sum x z w y)])
        row-1 (pt/stack [(twice-sum x y w z)
                         (diagonal w y x z)
                         (twice-diff y z w x)])
        row-2 (pt/stack [(twice-diff x z w y)
                         (twice-sum y z w x)
                         (diagonal w z x y)])]
    (pt/stack [row-0 row-1 row-2])))
...
;; `model` builds and samples a PyMC model that superimposes the two protein
;; structures. It takes one options map:
;;   :residues-limit  passed to `read-data` as :limit
;;   :tune            passed to `pm/sample` as :tune
;; and returns a map with :structures, :prior-predictive-samples,
;; :posterior-predictive-samples and :idata.
;;
;; The function is memoized, so calling it again with an equal options map
;; returns the cached result instead of re-sampling; `defonce` keeps that
;; cache when this form is evaluated again.
(defonce model
  (memoize
   (fn [{:keys [residues-limit tune]}]
     (let [{:keys [obs obs-datasets]}
           (read-data [protein-name1 protein-name2]
                      {:limit residues-limit})
           ;; Transpose each observation tensor (axes [1 0]).
           structures (->> obs
                           (mapv #(-> %
                                      (tensor/transpose [1 0]))))
           np-structures (->> structures
                              (mapv util/tensor2d->np-matrix))
           ;; Reversed shape of the (untransposed) first observation, i.e.
           ;; the shape of the transposed structures.
           shape (-> (obs 0)
                     dtype/shape
                     reverse
                     vec)
           [space-dimension n-residues] shape]
       ;; Note: this local `model` shadows the outer var of the same name
       ;; inside the `py/with` body.
       (py/with [model (pm/Model)]
                ;; M: latent structure with a Cauchy prior, same shape as
                ;; each observed structure.
                (let [M (pm/Cauchy "M"
                                   :alpha 0
                                   :beta 1
                                   :shape shape)
                      ;; NOTE(review): `pt/mean` is called without an axis
                      ;; argument here; if per-coordinate centering is
                      ;; intended, an axis may be needed -- confirm.
                      M0 (pm/Deterministic "M0"
                                           (operator/sub
                                            M
                                            (pt/mean M)))
                      t (pm/Normal "t" :shape [space-dimension]) ; the shift
                      u (pm/Uniform "u" :shape [space-dimension]) ; randomization of rotation
                      R (pm/Deterministic "R" (rotate-q u)) ; the rotation matrix
                      ;; U: per-residue values used below as the diagonal
                      ;; of the column covariance of X1 and X2.
                      U (pm/HalfNormal "U"
                                       :sigma 0.01 ; TODO: Consider some prior here
                                       :shape [n-residues])
                      M0_rotated (pm/Deterministic "M0_rotated"
                                                   (pt/dot R M0))
                      ;; First structure: observed around M0.
                      X1 (pm/MatrixNormal "X1"
                                          :mu M0
                                          :rowcov (np/eye space-dimension)
                                          :colcov (pt/diag U)
                                          :observed (np-structures 0))
                      ;; Second structure: observed around the rotated and
                      ;; shifted M0.
                      X2 (pm/MatrixNormal "X2"
                                          :mu (-> M0_rotated
                                                  ;; conjugating with transpose
                                                  ;; to make broadcasting work
                                                  pt/transpose
                                                  (operator/add t)
                                                  pt/transpose)
                                          :rowcov (np/eye space-dimension)
                                          :colcov (pt/diag U)
                                          :observed (np-structures 1))
                      ;; The three deterministics below apply the same
                      ;; rotate-then-shift transform to M0, to X1 and to the
                      ;; observed first structure, respectively.
                      M0_adapted (pm/Deterministic "M0_adapted"
                                                   (-> (pt/dot R M0)
                                                       pt/transpose
                                                       (operator/add t)
                                                       pt/transpose))
                      X1_adapted (pm/Deterministic "X1_adapted"
                                                   (-> (pt/dot R X1)
                                                       pt/transpose
                                                       (operator/add t)
                                                       pt/transpose))
                      ;; `show-results` reads this variable from the
                      ;; posterior.
                      prot1_adapted (pm/Deterministic "prot1_adapted"
                                                      (-> (np-structures 0)
                                                          (->> (pt/dot R))
                                                          pt/transpose
                                                          (operator/add t)
                                                          pt/transpose))
                      prior-predictive-samples (pm/sample_prior_predictive)
                      ;; One chain, 200 draws; the number of tuning steps
                      ;; comes from the caller.
                      idata (pm/sample :chains 1
                                       :draws 200
                                       :tune tune)
                      posterior-predictive-samples (pm/sample_posterior_predictive
                                                    idata)]
                  {:structures structures
                   :prior-predictive-samples prior-predictive-samples
                   :posterior-predictive-samples posterior-predictive-samples
                   :idata idata}))))))
nil
;; Run (or fetch the memoized result of) the model for 100 residues, 15 tuning steps.
(-> {:residues-limit 100 :tune 15}
    model)
{:structures [#tech.v3.tensor<object>[3 100]
[[ 9.595  13.34  14.12  16.27  14.92  14.56  13.25   9.845 9.544  7.460 7.627 4.379 4.230 3.901 0.7257 -2.643 -5.457 -8.799 -12.13 -12.71 -13.19  -14.33 -14.17 -14.20 -14.55 -10.74 -9.413 -10.14  -9.008 -8.636 -7.963 -4.995 -2.462  1.236  3.753  6.887  6.023  8.970  10.52  7.208  6.448  3.674 0.9007 0.05573 -3.307 -2.918 -4.943  -3.931 -4.639 -4.793 -6.048 -8.605 -8.173 -9.640 -7.165 -5.966 -2.469 -3.889 -4.879 -1.545 0.3107 -1.723 -0.9393  2.797  2.672   1.035   3.762   6.473  5.075  4.289  7.904  8.951  6.183  5.784  3.301 0.5967 -2.265 -1.028 -1.706  1.451  3.100  4.524  5.649  4.702 4.686  6.312 7.680 4.867 1.573 -0.08527   1.285 0.1167 0.4817 0.5327 -0.6753 -1.898 -5.249 -7.753 -10.45 -13.31]
 [-2.687 -2.544 0.2389 -1.951 -1.808  1.865 0.7629 -0.2181 2.107  5.178 8.275 8.489 12.26 12.06  9.997  11.64  9.197  9.869  8.067  4.432  2.393 -0.6111 -2.453 -5.970 -7.223 -6.831 -4.381 -1.743 -0.2131  3.546  5.844  8.123  10.41  9.954  12.62  11.73  8.043  6.102  9.159  10.46  14.21  15.58  13.37   9.810  8.371  5.007  2.464   1.743 -1.509 -1.445 -3.468 -2.656 -2.940 -6.379 -7.934 -11.50 -12.83 -14.52 -11.13 -9.580 -12.53 -12.27  -8.518 -9.075 -12.06  -9.997  -7.318  -9.942 -11.52 -8.370 -7.280 -9.887 -9.499 -7.046 -7.463 -5.255 -7.364 -7.872 -5.137 -3.198 -2.693 0.6829 0.2549 0.3359 2.527 0.5129 3.328 5.652 4.370    1.272 -0.8891 0.4399 -1.315  1.025 -0.4131 0.5129  2.258 0.7899  3.252 0.8429]
 [ 10.45  9.779  7.300  5.007  1.393 0.5175 -2.854  -1.436 1.561 0.9625 3.280 5.188 4.946 1.154  1.277  1.590  2.179 0.4605 0.3115  1.267 -1.869 0.05551  3.395  4.670  8.245  8.909  6.375  3.817  0.5865 0.3255 -2.515 -2.061 -3.602 -3.314 -3.805 -5.686 -5.875 -7.432 -9.005 -10.52 -10.66 -8.504 -7.160  -8.333 -7.506 -5.651 -3.721 -0.1415  1.696  5.476  8.384  10.97  14.72  15.08  12.63  13.37  12.75  9.671  8.168  9.016  7.442  4.212   3.912  4.453  2.109 -0.6425 -0.2335 -0.4305 -3.533 -5.528 -5.159 -7.697 -10.22 -12.98 -15.73 -17.30 -16.17 -12.69 -10.11 -9.447 -6.053 -5.377 -1.824  1.847 4.911  7.695 9.969 11.05 12.16    10.82   8.016  4.612  1.301 -1.709  -4.991 -8.413 -8.613 -11.04 -11.63 -12.24]]
              #tech.v3.tensor<object>[3 100]
[[ 10.55  10.04  13.18  12.30  12.94  13.04  12.47  9.024  7.993  4.936  3.721 0.7633 -0.8127 -1.138 -3.119 -6.924  -8.353 -11.92  -14.27 -13.15 -12.74 -12.13 -11.09 -10.00 -6.972 -4.765 -5.391 -7.623 -7.801  -9.358 -9.853 -7.826 -6.778 -3.038 -2.058  1.247  1.941  5.457  5.769  2.728 0.3743 -3.187 -4.459 -4.275 -6.934 -5.118 -5.743  -4.195 -3.401 -2.981 -3.281 -6.541 -6.974 -6.376 -3.627 -1.052  2.785  2.354 0.1603  2.746  5.504  3.338  2.592  6.337  7.242  4.666  6.177  9.745 8.662 6.576 9.728 11.46 8.648 7.459 5.473 2.140 0.4083 1.730 0.02830 2.042 3.535  3.325  4.719   4.119  3.606  6.016  6.659  4.389  1.535 0.4353  2.099 0.1283   1.063 -0.06270 -0.8277 -2.638 -6.478 -8.012 -11.52 -13.32]
 [-3.081 -1.371 -2.272 -4.760 -3.423 -6.518 -4.281 -3.242 -6.000 -7.965 -11.38 -10.58  -13.98 -12.68 -9.547 -9.919  -6.873 -5.764  -2.777 0.4731  3.393  6.073  6.443  9.572  11.71  10.12  6.526  3.853  3.087 -0.1568 -1.636 -4.857 -7.564 -8.316 -11.67 -12.02 -8.256 -7.422 -10.98 -10.65 -13.67 -13.07 -10.29 -6.461 -4.194 -1.975 0.4552 -0.3799  2.434  1.365  2.679  2.905 0.5062  3.514  5.214  7.549  7.247  10.48  8.700  5.842  8.441  10.33  6.893  5.893  9.346  8.743  5.234  6.668 9.470 6.969 4.747 7.694 9.300 7.719 10.04 9.491  12.20 10.96   8.031 4.865 2.583 -1.144 -2.485  -2.790 -5.425 -4.180 -7.619 -10.12 -7.678 -3.995 -1.775 -1.959  0.4041   -1.083   1.705  2.442  2.187  5.587  4.380  7.764]
 [-12.86 -9.505 -7.517 -4.742 -1.212  1.113  4.222  2.853 0.3602  1.607 0.2622 -2.110  -1.086  2.548  1.378 0.9962 -0.8328 0.1012 -0.4978 -2.189 0.2842 -2.407 -6.087 -8.047 -9.103 -6.350 -7.602 -5.963 -2.196 -0.8448  2.644  3.026  5.521  5.864  7.487  9.457  9.013  10.39  11.86  14.22  14.49  13.13  10.81  10.61  9.032  6.478  3.624  0.2112 -2.262 -5.916 -9.492 -11.50 -14.48 -16.85 -14.74 -16.33 -15.96 -13.88 -11.27 -11.18 -10.39 -7.841 -6.145 -6.254 -4.792 -2.010 -1.406 -1.178 1.264 3.347 3.779 5.566 7.667 10.97 13.31 15.27  13.12 9.715   7.880 8.681 6.001  6.797  3.465 -0.3008 -3.016 -5.754 -7.391 -9.277 -10.15 -9.788 -7.127 -3.842 -0.9358    2.436   4.969  8.267  8.060  8.939  10.03  9.478]]],
 :prior-predictive-samples Inference data with groups:
	> prior
	> prior_predictive
	> observed_data,
 :posterior-predictive-samples Inference data with groups:
	> posterior_predictive
	> observed_data,
 :idata Inference data with groups:
	> posterior
	> sample_stats
	> observed_data}
(defn show-results
  "Renders a 3D Plotly figure comparing the posterior draws of the adapted
  first protein with the second protein.

  Arguments:
    results  the map returned by `model`; its :idata (posterior variable
             `prot1_adapted`) and :structures entries are used.
    opts     map with :view-limit, the number of rows kept per trace
             (applied with `tc/head`).

  Every posterior draw of `prot1_adapted` becomes one faint trace coloured
  by the chain it came from; the second structure is drawn on top as one
  opaque orange trace. Returns a `kind/hiccup` value."
  [results {:keys [view-limit]}]
  (let [tensor->cljs (fn [tensor]
                       (-> tensor
                           (tensor/transpose [1 0])
                           util/xyz-tensor->dataset
                           (tc/head view-limit)
                           util/prep-dataset-for-cljs))
        prot1-adapted (-> results
                          :idata
                          (py.- posterior)
                          (py.- prot1_adapted))
        ;; The first two dimensions are read as chains and samples.
        shape (np/shape prot1-adapted)
        n-chains (first shape)
        n-samples (second shape)]
    (->> {;; One dataset per (chain, sample), chain-major order.
          :prot1-adapted-datasets
          (-> prot1-adapted
              util/py-array->clj
              (tensor/slice 1)
              (->> (mapcat (fn [chain-tensor]
                             (->> (tensor/slice chain-tensor 1)
                                  (map tensor->cljs))))
                   vec))
          ;; Chain index of each dataset above, in the same order.
          :prot1-chain-idx (->> n-chains
                                range
                                (mapcat (fn [chain-idx]
                                          (repeat n-samples chain-idx)))
                                vec)
          :prot2-dataset
          (-> results
              :structures
              second
              tensor->cljs)}
         (vector '(fn [{:keys [prot1-adapted-datasets
                               prot1-chain-idx
                               prot2-dataset]}]
                    (let [chain-colors ["blue" "yellow" "red" "green"]]
                      [plotly
                       {:data (->> (map (fn [dataset chain-idx]
                                          (-> dataset
                                              (merge {:type :scatter3d
                                                      :mode :lines+markers
                                                      :opacity 0.1
                                                      ;; One colour per trace,
                                                      ;; picked by its chain.
                                                      ;; An array here would
                                                      ;; be read per point.
                                                      ;; `mod` wraps the
                                                      ;; palette when there
                                                      ;; are more chains than
                                                      ;; colours.
                                                      :marker {:size 3
                                                               :color
                                                               (nth chain-colors
                                                                    (mod chain-idx
                                                                         (count chain-colors)))}})))
                                        prot1-adapted-datasets
                                        prot1-chain-idx)
                                   (cons (-> prot2-dataset
                                             (merge {:type :scatter3d
                                                     :mode :lines+markers
                                                     :opacity 1
                                                     :marker {:size 3
                                                              :color "orange"}})))
                                   vec)}])))
         kind/hiccup)))
...
;; 100 residues, 200 tuning steps.
(show-results (model {:residues-limit 100 :tune 200})
              {:view-limit 50})
...
;; 100 residues, 50 tuning steps.
(show-results (model {:residues-limit 100 :tune 50})
              {:view-limit 50})
...
;; 100 residues, 15 tuning steps.
(show-results (model {:residues-limit 100 :tune 15})
              {:view-limit 50})
...
;; 100 residues, 5 tuning steps.
(show-results (model {:residues-limit 100 :tune 5})
              {:view-limit 50})
...
:bye
:bye